import os
import glob
import numpy as np
import pandas as pd
import joblib
import umap.umap_ as umap
from sklearn.manifold import MDS
import time
import plotly.express as px
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from tslearn.barycenters import euclidean_barycenter
from matplotlib.backends.backend_pdf import PdfPages
from constants import DIST_FILENAMES, FEATURE_NAMES, WEEK_MEASUREMENTS,FEAT_IDX,FEAT_LABEL

def get_distance_matrix(data, distance_metric, directory):
    """
    Returns the pairwise distance matrix of ``data``, cached on disk.

    The cache file is ``directory`` concatenated with
    ``DIST_FILENAMES[distance_metric]``. If that file exists it is loaded
    and returned as-is (``data`` is not consulted); otherwise the matrix is
    computed with ``scipy.spatial.distance.cdist`` and saved there.

    Args:
        data (numpy.ndarray): A 2D array of shape (n_samples, n_features).
        distance_metric (str): One of "manhattan", "euclidean", "minkowski"
            (computed with p=4) or "cosine". Must also be a key of
            DIST_FILENAMES.
        directory (str): Prefix prepended directly to the cache filename,
            so it has to end with a path separator.

    Returns:
        numpy.ndarray: A 2D array of shape (n_samples, n_samples) containing
            the pairwise distances between data points.

    Raises:
        ValueError: If no cache file exists and ``distance_metric`` is not
            one of the supported metrics.
    """
    # Keyword arguments handed to cdist for every supported metric.
    cdist_kwargs_by_metric = {
        "manhattan": {"metric": "cityblock"},
        "euclidean": {"metric": "euclidean"},
        "minkowski": {"metric": "minkowski", "p": 4},
        "cosine": {"metric": "cosine"},
    }

    filename = f"{directory}{DIST_FILENAMES[distance_metric]}"
    print (f"Checking for distance matrix file... {filename}")
    if os.path.exists(filename):
        print(f"Path for distance matrix exists..loading existing..")
        return np.load(filename)

    print (f"File not found, recalculating...")
    if distance_metric not in cdist_kwargs_by_metric:
        # Fail with a clear message instead of an UnboundLocalError later on.
        supported = ", ".join(sorted(cdist_kwargs_by_metric))
        raise ValueError(
            f"Unsupported distance metric {distance_metric!r}; "
            f"expected one of: {supported}"
        )

    distance_matrix = cdist(data, data, **cdist_kwargs_by_metric[distance_metric])
    np.save(filename, distance_matrix)
    return distance_matrix

def generate_mds_maps(directory, distance_metric):
    """
    Returns the 2D and 3D MDS embeddings for a precomputed distance matrix.

    The distance matrix is read from ``directory`` +
    ``DIST_FILENAMES[distance_metric]``. Each embedding is cached as
    ``<directory><distance_metric>_mds2d.npy`` / ``..._mds3d.npy``; an
    existing cache file is loaded instead of refitting.

    Args:
        directory (str): Prefix prepended directly to every filename.
        distance_metric (str): Key into DIST_FILENAMES, also used in the
            cache filenames.

    Returns:
        list: ``[embedding_2d, embedding_3d]`` as numpy arrays.
    """
    distances = np.load(f"{directory}{DIST_FILENAMES[distance_metric]}")

    embeddings = []
    for suffix, n_dims in (("mds2d", 2), ("mds3d", 3)):
        cache_path = directory + distance_metric + f"_{suffix}.npy"
        if not os.path.exists(cache_path):
            # No cached embedding: fit metric MDS on the precomputed distances.
            model = MDS(
                n_components=n_dims,
                metric=True,
                dissimilarity='precomputed',
                random_state=0,
            )
            model.fit(distances)
            embedding = model.embedding_
            np.save(cache_path, embedding)
        else:
            embedding = np.load(cache_path)
        embeddings.append(embedding)

    return embeddings
    
        
def generate_umap_maps(directory, distance_metric):
    """
    Returns the 2D and 3D UMAP embeddings for a precomputed distance matrix.

    The distance matrix is read from ``directory`` +
    ``DIST_FILENAMES[distance_metric]``. Each embedding is cached as
    ``<directory><distance_metric>_umap2d.npy`` / ``..._umap3d.npy``; an
    existing cache file is loaded instead of refitting.

    Args:
        directory (str): Prefix prepended directly to every filename.
        distance_metric (str): Key into DIST_FILENAMES, also used in the
            cache filenames.

    Returns:
        list: ``[embedding_2d, embedding_3d]`` as numpy arrays.
    """
    print ("Generating UMAPS 2D and 3D")
    distance_matrix_file = f"{directory}{DIST_FILENAMES[distance_metric]}"
    distance_matrix_np = np.load(distance_matrix_file)
    n_neighbors = 20 # choose n_neighbors
    umaps = []
    for map_name, n_components in (("umap2d", 2), ("umap3d", 3)):
        location = directory + distance_metric + "_" + map_name + ".npy"
        if os.path.exists(location):
            umap_embedding = np.load(location)
        else:
            mapper = umap.UMAP(n_neighbors=n_neighbors,
                            min_dist=0,
                            n_components=n_components,
                            metric='precomputed',
                            set_op_mix_ratio=0.05,
                            transform_queue_size=10,
                            densmap=False,
                            init='spectral',
                            random_state=0
                            ).fit(distance_matrix_np) # choose distance matrix
            umap_embedding = mapper.embedding_
            np.save(location, umap_embedding)
            print (f"Saved UMAP embedding at : {location}")
        umaps.append(umap_embedding)
    return umaps

def create_data_array(directory):
    """
    Builds the normalized vehicle-week array from the per-VIN .csv files.

    Every ``*.csv`` in ``directory`` is read, its FEATURE_NAMES columns are
    normalized with the scaler loaded from ``<directory>scaler.save`` and
    the rows are cut into non-overlapping windows of WEEK_MEASUREMENTS rows.
    A window is dropped when the summed absolute value of raw feature
    column 2 is below 10 or when raw feature column 1 contains -1;
    otherwise its normalized values are kept.

    Args:
        directory (str): Dataset folder. It is also used as a plain string
            prefix for ``scaler.save``, so it has to end with a path
            separator.

    Returns:
        tuple: ``(data, valid_vins, drop_vins, valid_weeks, drop_weeks,
        scaler)`` where ``data`` is a numpy array stacking the kept
        normalized windows, the ``*_vins`` lists hold the file name (without
        extension) of each kept/dropped window, the ``*_weeks`` lists hold
        the ``week_num`` value at the first row of each window, and
        ``scaler`` is the loaded scaler object.
    """
    data_raw = []
    valid_vins, valid_weeks = [], []
    drop_vins, drop_weeks = [], []
    scaler = joblib.load(f"{directory}scaler.save") # dataset folder
    files = sorted(glob.glob(os.path.join(directory, '*.csv'))) # list of all files from dataset folder
    for vin_file in files: # loop through VIN files
        # File name without extension; separator-agnostic unlike split('/').
        vin = os.path.splitext(os.path.basename(vin_file))[0]
        vin_df = pd.read_csv(vin_file)
        data_features = vin_df.loc[:, FEATURE_NAMES].copy() # extract features
        data_features_norm = scaler.transform(data_features) # normalization
        data_features = data_features.values
        # sliding window (size = week, stride = week)
        for w in range(0, len(data_features) - WEEK_MEASUREMENTS + 1, WEEK_MEASUREMENTS):
            window = data_features[w:w + WEEK_MEASUREMENTS, :]
            week_num = vin_df.loc[w, 'week_num']

            # drop vehicle-weeks where total absolute change in SOC < 10
            # NOTE(review): assumes feature column 2 is the SOC change and
            # that -1 in column 1 marks an invalid reading -- confirm
            # against FEATURE_NAMES.
            low_activity = np.abs(window[:, 2]).sum() < 10
            has_invalid = (window[:, 1] == -1).any()
            if low_activity or has_invalid:
                drop_vins.append(vin)
                drop_weeks.append(week_num)
            else:
                valid_vins.append(vin)
                valid_weeks.append(week_num)
                data_raw.append(data_features_norm[w:w + WEEK_MEASUREMENTS, :])
    return np.array(data_raw), valid_vins, drop_vins, valid_weeks, drop_weeks, scaler

def create_non_drop_data_arr(directory):
    """
    Returns a numpy array containing all vehicle-weeks (no dropping).

    Every ``*.csv`` in ``directory`` (in sorted order) is cut into
    non-overlapping windows of WEEK_MEASUREMENTS rows over the
    FEATURE_NAMES columns; trailing rows that do not fill a whole window
    are ignored. Values are the raw (un-normalized) features.
    """
    windows = []
    for csv_path in sorted(glob.glob(os.path.join(directory, '*.csv'))):
        features = pd.read_csv(csv_path).loc[:, FEATURE_NAMES].copy().values
        # Start offsets of each full week (window size == stride == week).
        week_starts = range(0, len(features) - WEEK_MEASUREMENTS + 1, WEEK_MEASUREMENTS)
        windows.extend(
            features[start:start + WEEK_MEASUREMENTS, :] for start in week_starts
        )
    return np.array(windows)

def get_train_vins(directory):
    """
    Returns the VIN of every .csv file in the dataset folder.

    The VIN is the file name without its extension. Files are listed in
    sorted path order and the count is printed.

    Args:
        directory (str): Dataset folder to scan for ``*.csv`` files.

    Returns:
        list: VIN strings, one per .csv file.
    """
    csvfiles = sorted(glob.glob(os.path.join(directory, '*.csv'))) # list of all files from dataset folder
    # basename/splitext work with any path separator, unlike split('/').
    train_vins = [os.path.splitext(os.path.basename(path))[0] for path in csvfiles]
    print (f"Num of VINS: {len(train_vins)}")
    return train_vins

def plot_save_clusters(labels, params,
                        distance_metric_umap = "euclidean", 
                        distance_metric_mds= "euclidean") -> str:
    """
    Plots the UMAP embeddings colored by cluster assignment and saves them.

    Writes ``umap2d.png`` (matplotlib scatter) and ``umap3d.html`` (plotly
    3D scatter) into an output directory derived from ``params``, creating
    the directory if needed.

    Args:
        labels (numpy.ndarray): A 1D array of shape (n_samples,) containing
            the cluster assignments for each data point. Any sequence is
            accepted and converted to an array.
        params: Object providing ``n_clusters``, ``features_dir``,
            ``cluster_algo``, ``distance_metric``, ``random_state``, ``dr``
            and, for the "agglomerative" algorithm, ``distance_threshold``.
        distance_metric_umap (str): Currently unused; kept for interface
            compatibility.
        distance_metric_mds (str): Currently unused; kept for interface
            compatibility.

    Returns:
        str: The output directory the plots were written to.
    """
    labels = np.asarray(labels)  # .astype below needs an array, not a list
    optimal = params.n_clusters
    base_dir = os.path.join(params.features_dir, params.cluster_algo)
    base_dir = f"{base_dir}/K_{params.n_clusters}_distance_{params.distance_metric}_random_{params.random_state}"
    if params.cluster_algo == "agglomerative":
        base_dir = f"{base_dir}_{params.distance_threshold}"
    if params.dr:
        base_dir  = f"{base_dir}/DR/"

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(base_dir, exist_ok=True)

    umap2d, umap3d = generate_umap_maps(params.features_dir, params.distance_metric)
    umap2d_filename = f"{base_dir}/umap2d.png"
    umap3d_filename = f"{base_dir}/umap3d.html"

    fig2d = plt.figure(figsize=(5,5))
    scatter = plt.scatter(umap2d.T[0], umap2d.T[1], c=labels, s=5)
    plt.legend(*scatter.legend_elements(), bbox_to_anchor=(1,1), loc='upper left')
    plt.xlabel('')
    plt.ylabel('')
    plt.axis('square')
    plt.savefig(umap2d_filename)
    # Release the figure; otherwise every call leaves one open in pyplot.
    plt.close(fig2d)

    fig = px.scatter_3d(x=umap3d.T[0], y=umap3d.T[1], z=umap3d.T[2],
                        color=labels.astype(str), color_discrete_sequence=px.colors.qualitative.Dark24,
                        labels=dict(x='', y='', z=''))
    fig.update_traces(marker=dict(size=5, opacity=0.2))
    fig.update_layout(title={'text':'Clusters='+str(optimal), 'x':0.5, 'y':0.9, 'xanchor':'center', 'yanchor':'top'})
    fig.write_html(umap3d_filename)
    return base_dir

def generate_output_df(train_vins, directory, valid_train_vins, valid_train_weeks, drop_train_vins, drop_train_weeks, labels):
    """
    Builds the per vehicle-week result table with cluster, date and covid flag.

    Valid vehicle-weeks get their cluster label, dropped ones get cluster
    -1. Rows are sorted by (vin, week) and the resulting position is stored
    in the ``index`` column. For every VIN in ``train_vins`` the file
    ``<directory><vin>.csv`` is read and the ``datetime`` string at the
    first row of the i-th block of WEEK_MEASUREMENTS rows (minus its last 9
    characters) becomes the ``date`` of the row whose ``week`` equals i.
    ``covid`` is 1 when that date string sorts after '2020-03-11', else 0.

    Rows for which no date could be found keep NaN in both ``date`` and
    ``covid``.

    Args:
        train_vins (list): VINs whose .csv files supply the dates.
        directory (str): Prefix prepended directly to ``<vin>.csv``.
        valid_train_vins (list): VIN of each clustered vehicle-week.
        valid_train_weeks (list): Week of each clustered vehicle-week.
        drop_train_vins (list): VIN of each dropped vehicle-week.
        drop_train_weeks (list): Week of each dropped vehicle-week.
        labels: Cluster label of each clustered vehicle-week.

    Returns:
        pandas.DataFrame: Columns ``vin``, ``week``, ``cluster``, ``index``,
        ``date`` and ``covid``.
    """
    # create dataframe to store labels
    valid = pd.DataFrame(list(zip(valid_train_vins, valid_train_weeks, labels)),
                        columns=['vin', 'week', 'cluster'])
    dropped = pd.DataFrame(list(zip(drop_train_vins, drop_train_weeks, [-1]*len(drop_train_vins))),
                        columns=['vin','week','cluster'])
    output = pd.concat([valid, dropped], ignore_index=True)
    output = output.sort_values(['vin','week']).reset_index(drop=True)
    output['index'] = output.index

    # get dates: one (vin, week) -> date lookup instead of re-scanning the
    # whole frame with a boolean mask for every week of every VIN
    print (f"Len of train_vins: {len(train_vins)}")
    week_start_dates = {}
    for vin in train_vins:
        datetimes = pd.read_csv(f"{directory}{vin}.csv")['datetime']
        for i in range(len(datetimes)//WEEK_MEASUREMENTS):
            # [:-9] strips the trailing 9 characters of the timestamp
            # (presumably ' HH:MM:SS' -- verify against the data files).
            week_start_dates[(vin, i)] = datetimes.iloc[i*WEEK_MEASUREMENTS][:-9]
    output['date'] = [
        week_start_dates.get(key, np.nan)
        for key in zip(output['vin'], output['week'])
    ]

    # pre-Covid = 0, Covid = 1 (NaN when the date is unknown, rather than
    # failing on a float/str comparison)
    output['covid'] = [
        np.nan if pd.isna(date) else float(date > '2020-03-11')
        for date in output['date']
    ]
    print (f"{output.shape}")
    return output

def plot_and_save_load_profiles(train_characterize, train_raw_ns, output, algorithm, base_directory, scaler, cluster_results):
    """
    Saves one PDF per feature with a page of weekly profiles per cluster.

    For every feature in FEAT_IDX a file
    ``<base_directory>/raw_<feature>_plot.pdf`` is written. Each page shows
    all vehicle-weeks of one cluster in gray plus a red representative
    profile: for "kmedoids" the inverse-transformed medoid taken from
    ``train_raw_ns``, for any other algorithm the Euclidean barycenter of
    the cluster's profiles. Cluster -1 gets no representative.

    Args:
        train_characterize: Indexable by the ``index`` column of ``output``;
            each item transposed and indexed by the feature index gives the
            series that is plotted.
        train_raw_ns: Indexable by the ``index`` column of ``output``; only
            read for "kmedoids", where the item is passed to
            ``scaler.inverse_transform``.
        output (pandas.DataFrame): Needs ``cluster`` and ``index`` columns.
        algorithm (str): "kmedoids" selects the medoid as representative.
        base_directory (str): Folder the PDFs are written to.
        scaler: Object with ``inverse_transform``; only used for "kmedoids".
        cluster_results: Mapping; its "medoids" entry is only read for
            "kmedoids" and is indexed by cluster label to obtain a row
            label of ``output``.

    Returns:
        bool: Always True.
    """
    median = algorithm == "kmedoids"
    # Only k-medoids results need to carry medoids; do not require the key
    # for the other algorithms.
    medoids = cluster_results["medoids"] if median else None

    # Fixed y-axis range per feature; features not listed autoscale.
    y_limits = {
        'Home': (-0.1, 1.1),
        'SOC': (-5, 105),
        'delta_soc': (-12, 35),
        'dod': (-5, 105),
        'charging_power_level': (-0.1, 3.1),
        'charging_energy_kwh': (-5, 85),
        'weekly_cycle': (0, 13),
    }
    days = ['Mon','Tue','Wed','Thu','Fri','Sat','Sun']
    day_step = int(WEEK_MEASUREMENTS/7)

    grouped_output = output.groupby('cluster')
    cluster_ids = sorted(set(output['cluster']))
    for feature, f in FEAT_IDX.items():
        with PdfPages(base_directory + '/raw_'+feature+'_plot.pdf') as pdf:
            for l in cluster_ids:
                fig, ax = plt.subplots(1,1,figsize=(8,3))
                cl = grouped_output.get_group(l)
                profile = []
                for i in cl.index:
                    data = train_characterize[cl.loc[i,'index']].T[f]
                    plt.plot(data, c='gray', alpha=0.3)
                    if not median:
                        profile.append(data)

                # Representative profile in red (none for cluster -1).
                if l != -1:
                    if median:
                        plt.plot(scaler.inverse_transform(train_raw_ns[output.loc[medoids[l],'index']]).T[f], c='r')
                    else:
                        plt.plot(euclidean_barycenter(profile), c='r')

                # Dotted separators between days and day names underneath.
                for i in range(day_step, WEEK_MEASUREMENTS-2, day_step):
                    plt.axvline(i, color='k', linestyle='dotted')
                for i, day in enumerate(days):
                    plt.text(i/9+0.16, 0.05, day, transform=plt.gcf().transFigure)

                plt.title('Cluster '+str(l), fontsize=16)
                plt.grid(False)
                plt.xlim(0,WEEK_MEASUREMENTS)
                plt.xticks([])

                if feature in y_limits:
                    plt.ylim(*y_limits[feature])

                ax.spines['right'].set_visible(False)
                ax.spines['top'].set_visible(False)
                plt.ylabel(FEAT_LABEL[feature])
                pdf.savefig()
                plt.close()
    return True




